{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Text Based Inpainting Stable Diffusion\n",
    "\n",
    "Use a text prompt to generate the mask for the area to be inpainted. Currently uses the CLIPSeg model for mask generation, then calls the standard Stable Diffusion Inpainting pipeline to perform the inpainting. This script was contributed by [Dhruv Karan](https://github.com/unography) and the notebook by [Parag Ekbote](https://github.com/ParagEkbote)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: transformers in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (4.48.3)\n",
      "Requirement already satisfied: diffusers in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (0.32.2)\n",
      "Requirement already satisfied: accelerate in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (1.3.0)\n",
      "Requirement already satisfied: filelock in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from transformers) (3.17.0)\n",
      "Requirement already satisfied: huggingface-hub<1.0,>=0.24.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from transformers) (0.28.1)\n",
      "Requirement already satisfied: numpy>=1.17 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from transformers) (1.26.4)\n",
      "Requirement already satisfied: packaging>=20.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from transformers) (24.2)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from transformers) (6.0.2)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from transformers) (2024.11.6)\n",
      "Requirement already satisfied: requests in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from transformers) (2.32.3)\n",
      "Requirement already satisfied: tokenizers<0.22,>=0.21 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from transformers) (0.21.0)\n",
      "Requirement already satisfied: safetensors>=0.4.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from transformers) (0.5.2)\n",
      "Requirement already satisfied: tqdm>=4.27 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from transformers) (4.67.1)\n",
      "Requirement already satisfied: importlib-metadata in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (8.6.1)\n",
      "Requirement already satisfied: Pillow in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (11.1.0)\n",
      "Requirement already satisfied: psutil in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from accelerate) (6.1.1)\n",
      "Requirement already satisfied: torch>=2.0.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from accelerate) (2.2.1+cu121)\n",
      "Requirement already satisfied: fsspec>=2023.5.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.24.0->transformers) (2025.2.0)\n",
      "Requirement already satisfied: typing-extensions>=3.7.4.3 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.24.0->transformers) (4.12.2)\n",
      "Requirement already satisfied: sympy in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=2.0.0->accelerate) (1.13.3)\n",
      "Requirement already satisfied: networkx in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=2.0.0->accelerate) (3.4.2)\n",
      "Requirement already satisfied: jinja2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=2.0.0->accelerate) (3.1.5)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=2.0.0->accelerate) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=2.0.0->accelerate) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=2.0.0->accelerate) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=2.0.0->accelerate) (8.9.2.26)\n",
      "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=2.0.0->accelerate) (12.1.3.1)\n",
      "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=2.0.0->accelerate) (11.0.2.54)\n",
      "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=2.0.0->accelerate) (10.3.2.106)\n",
      "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=2.0.0->accelerate) (11.4.5.107)\n",
      "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=2.0.0->accelerate) (12.1.0.106)\n",
      "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=2.0.0->accelerate) (2.19.3)\n",
      "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=2.0.0->accelerate) (12.1.105)\n",
      "Requirement already satisfied: triton==2.2.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=2.0.0->accelerate) (2.2.0)\n",
      "Requirement already satisfied: nvidia-nvjitlink-cu12 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=2.0.0->accelerate) (12.8.61)\n",
      "Requirement already satisfied: zipp>=3.20 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from importlib-metadata->diffusers) (3.21.0)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->transformers) (3.4.1)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->transformers) (3.10)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->transformers) (2.3.0)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->transformers) (2025.1.31)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from jinja2->torch>=2.0.0->accelerate) (3.0.2)\n",
      "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from sympy->torch>=2.0.0->accelerate) (1.3.0)\n",
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "pip install transformers diffusers accelerate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a5f13a2e143d4a38b023bd31cdc073c0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "An error occurred while trying to fetch /home/zeus/.cache/huggingface/hub/models--runwayml--stable-diffusion-inpainting/snapshots/8a4288a76071f7280aedbdb3253bdb9e9d5d84bb/unet: Error no file named diffusion_pytorch_model.safetensors found in directory /home/zeus/.cache/huggingface/hub/models--runwayml--stable-diffusion-inpainting/snapshots/8a4288a76071f7280aedbdb3253bdb9e9d5d84bb/unet.\n",
      "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.\n",
      "An error occurred while trying to fetch /home/zeus/.cache/huggingface/hub/models--runwayml--stable-diffusion-inpainting/snapshots/8a4288a76071f7280aedbdb3253bdb9e9d5d84bb/vae: Error no file named diffusion_pytorch_model.safetensors found in directory /home/zeus/.cache/huggingface/hub/models--runwayml--stable-diffusion-inpainting/snapshots/8a4288a76071f7280aedbdb3253bdb9e9d5d84bb/vae.\n",
      "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.\n",
      "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/transformers/image_processing_utils.py:41: UserWarning: The following named arguments are not valid for `ViTImageProcessor.preprocess` and were ignored: 'padding'\n",
      "  return self.preprocess(images, **kwargs)\n",
      "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py:1043: FutureWarning: `callback_steps` is deprecated and will be removed in version 1.0.0. Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`\n",
      "  deprecate(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8f70795c72c34096bc03df9b64bb638e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Inpainting completed. Image saved as 'inpainting_output.png'.\n"
     ]
    }
   ],
   "source": [
    "from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation\n",
    "from diffusers import DiffusionPipeline\n",
    "from PIL import Image\n",
    "import requests\n",
    "import torch\n",
    "\n",
    "# Load CLIPSeg model and processor\n",
    "processor = CLIPSegProcessor.from_pretrained(\"CIDAS/clipseg-rd64-refined\")\n",
    "model = CLIPSegForImageSegmentation.from_pretrained(\"CIDAS/clipseg-rd64-refined\").to(\"cuda\")\n",
    "\n",
    "# Load Stable Diffusion Inpainting Pipeline with custom pipeline\n",
    "pipe = DiffusionPipeline.from_pretrained(\n",
    "    \"runwayml/stable-diffusion-inpainting\",\n",
    "    custom_pipeline=\"text_inpainting\",\n",
    "    segmentation_model=model,\n",
    "    segmentation_processor=processor\n",
    ").to(\"cuda\")\n",
    "\n",
    "# Load input image\n",
    "url = \"https://github.com/timojl/clipseg/blob/master/example_image.jpg?raw=true\"\n",
    "image = Image.open(requests.get(url, stream=True).raw)\n",
    "\n",
    "# Step 1: Resize input image for CLIPSeg (224x224)\n",
    "segmentation_input = image.resize((224, 224))\n",
    "\n",
    "# Step 2: Generate segmentation mask\n",
    "text = \"a glass\"  # Object to mask\n",
    "inputs = processor(text=text, images=segmentation_input, return_tensors=\"pt\").to(\"cuda\")\n",
    "\n",
    "with torch.no_grad():\n",
    "    mask = model(**inputs).logits.sigmoid()  # Get segmentation mask\n",
    "\n",
    "# Resize mask back to 512x512 for SD inpainting\n",
    "mask = torch.nn.functional.interpolate(mask.unsqueeze(0), size=(512, 512), mode=\"bilinear\").squeeze(0)\n",
    "\n",
    "# Step 3: Resize input image for Stable Diffusion\n",
    "image = image.resize((512, 512))\n",
    "\n",
    "# Step 4: Run inpainting with Stable Diffusion\n",
    "prompt = \"a cup\"  # The masked-out region will be replaced with this\n",
    "result = pipe(image=image, mask=mask, prompt=prompt,text=text).images[0]\n",
    "\n",
    "# Save output\n",
    "result.save(\"inpainting_output.png\")\n",
    "print(\"Inpainting completed. Image saved as 'inpainting_output.png'.\")\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
